import matplotlib
matplotlib.use('Agg')

from train_and_eval_utils import *

from tensorflow.keras.applications import VGG16, VGG19, ResNet50, InceptionV3, Xception
from tensorflow.keras import models
from tensorflow.keras import layers
from tensorflow.keras.preprocessing import image
from keras.applications import imagenet_utils
from keras.applications.inception_v3 import preprocess_input
import tensorflow as tf
from tensorflow import keras
from skimage.feature.peak import peak_local_max

import argparse
import ast
import cv2
import h5py
import math
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import scipy
import shutil
import sys
import time

from PIL import Image
from tqdm import tqdm

# Command-line options, registered in this order. Each option is accepted under
# both a single-dash and a double-dash spelling; argparse stores the parsed
# value under the long name (e.g. script_args["lr"]).
_CLI_OPTIONS = (
    # (short flag,         long flag,              type,  default)
    ("-model",             "--model",              str,   "vgg16_cam"),
    ("-save_dir",          "--save_dir",           str,   "keras_model"),
    ("-decay",             "--decay",              float, 1e-7),
    ("-lr",                "--lr",                 float, 2e-5),
    ("-num_epochs",        "--num_epochs",         int,   8),
    ("-batch_size",        "--batch_size",         int,   32),
    ("-dropout_rate",      "--dropout_rate",       float, 0.6),
    ("-checkpoint",        "--checkpoint",         int,   0),
    ("-max_steps",         "--max_steps",          int,   1000),
    ("-max_checkpoints",   "--max_checkpoints",    int,   1000),
    ("-threshold",         "--threshold",          float, 0.5),
    ("-buffer_size",       "--buffer_size",        int,   256),
    ("-mode",              "--mode",               str,   "train"),
    ("-train_filename",    "--train_filename",     str,   ""),
    ("-test_filename",     "--test_filename",      str,   ""),
    ("-results_csv",       "--results_csv",        str,   ""),
)

ap = argparse.ArgumentParser()
for _short_flag, _long_flag, _opt_type, _opt_default in _CLI_OPTIONS:
    ap.add_argument(_short_flag, _long_flag, type = _opt_type, default = _opt_default)
# Drop the loop variables so they do not linger as module globals.
del _short_flag, _long_flag, _opt_type, _opt_default

# Parsed options as a plain dict keyed by long option name.
script_args = vars(ap.parse_args())

# Model-name -> Keras application constructor. The keys look like the values
# accepted by --model (its default is "vgg16_cam"); presumably looked up by the
# model factory — confirm against create_model.
MODELS = dict(
    vgg16=VGG16,
    vgg16_cam=VGG16,    # same constructor as "vgg16"
    vgg19=VGG19,
    inception=InceptionV3,
    xception=Xception,  # TensorFlow ONLY
    resnet=ResNet50,
    # Not a Keras application class: this entry maps to the plain string.
    simple="simple",
)

# NOTE(review): not referenced by the code shown here; presumably consumed
# elsewhere — confirm before changing.
ZOOM_LEVEL = 17

def train(args, model, model_dir):
    """Train ``model`` through the tf.estimator API, evaluating on a validation split.

    The Keras model is wrapped with ``tf.keras.estimator.model_to_estimator``
    and driven by ``tf.estimator.train_and_evaluate``. Checkpoints are saved
    every 10 minutes into ``model_dir``.

    Args:
        args: dict of parsed command-line options. Reads "num_epochs",
            "max_checkpoints", "batch_size" and "max_steps"; the whole dict is
            also passed to ``load_data`` and ``imgs_input_fn``.
        model: Keras model to wrap; its first input name is used to feed
            images to the estimator.
        model_dir: directory for estimator checkpoints and summaries.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    # Single-argument print(...) renders the same under Python 2 and 3; the
    # multi-argument form prints a tuple under Python 2.
    print("model_dir: %s" % model_dir)
    print("num_epochs: %s" % args["num_epochs"])
    print("maximum number of checkpoints: %s" % args["max_checkpoints"])

    run_config = tf.estimator.RunConfig(
        save_checkpoints_secs = 10 * 60,
        keep_checkpoint_max = args["max_checkpoints"],
    )
    kiln_model = tf.keras.estimator.model_to_estimator(keras_model = model,
                                                       model_dir = model_dir,
                                                       config = run_config)

    input_name = model.input_names[0]
    train_images, train_labels, val_images, val_labels = load_data(args = args)

    def _train_input_fn():
        # Shuffled, repeated for num_epochs; training stops at max_steps
        # if that comes first.
        return imgs_input_fn(images = train_images,
                             labels = train_labels,
                             args = args,
                             perform_shuffle = True,
                             repeat_count = args["num_epochs"],
                             batch_size = args["batch_size"],
                             input_name = input_name,
                             training = True)

    def _eval_input_fn():
        # One deterministic pass over the validation data.
        return imgs_input_fn(images = val_images,
                             labels = val_labels,
                             args = args,
                             perform_shuffle = False,
                             repeat_count = 1,
                             batch_size = 128,
                             input_name = input_name,
                             training = False)

    train_spec = tf.estimator.TrainSpec(input_fn = _train_input_fn,
                                        max_steps = args["max_steps"])
    eval_spec = tf.estimator.EvalSpec(input_fn = _eval_input_fn)

    start_time = time.time()
    tf.estimator.train_and_evaluate(kiln_model, train_spec, eval_spec)
    print("--- %s seconds ---" % (time.time() - start_time))


def evaluate(args, model, model_dir):
    """Run inference on the test set, report accuracy and timing, and save results.

    The Keras model is wrapped with ``tf.keras.estimator.model_to_estimator``.
    Each predicted score is turned into a 0/1 label: scores below
    ``args["threshold"]`` become 0, all others become 1.

    Args:
        args: dict of parsed command-line options. Reads "test_filename",
            "checkpoint", "threshold" and "results_csv"; the whole dict is
            also passed to ``imgs_input_fn``.
        model: Keras model to wrap. Its first input name feeds images and its
            first output name selects the score in each prediction.
        model_dir: estimator model directory. Checkpoints are read from here,
            either the latest one or ``model.ckpt-<checkpoint>`` when
            ``args["checkpoint"]`` is non-zero.

    A CSV with columns ``prob`` and ``prediction`` (one row per test image) is
    written to ``args["results_csv"]``. When that option is empty, the CSV goes
    next to the test file, with the test file's extension replaced by
    ``_results.csv``.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    test_images, test_labels = load_data_helper(hdf5_file_name = args["test_filename"], train = False)

    input_name = model.input_names[0]
    # model_to_estimator keys each prediction dict by the Keras output names,
    # so use the model's own output name rather than a hard-coded layer name.
    output_name = model.output_names[0]
    kiln_model = tf.keras.estimator.model_to_estimator(keras_model = model, model_dir = model_dir)

    # None makes Estimator.predict use the latest checkpoint in model_dir.
    checkpoint_path = None
    if args["checkpoint"] != 0:
        checkpoint_path = os.path.join(model_dir, 'model.ckpt-' + str(args["checkpoint"]))

    def _test_input_fn():
        # One unshuffled pass, so predictions line up with test_labels.
        return imgs_input_fn(images = test_images,
                             labels = test_labels,
                             args = args,
                             perform_shuffle = False,
                             repeat_count = 1,
                             batch_size = 128,
                             input_name = input_name,
                             training = False)

    start_time = time.time()
    results = list(tqdm(kiln_model.predict(_test_input_fn, checkpoint_path = checkpoint_path)))

    # Takes element 0 of each output vector; assumes a single-unit score
    # output — TODO confirm against the model definition.
    threshold = args['threshold']
    prob = [result[output_name][0] for result in results]
    prediction = [0 if float(score) < threshold else 1 for score in prob]

    accuracy = calculate_accuracy(pred = prediction, label = test_labels)

    print("")
    print("Model accuracy: %s" % accuracy)
    print("")

    # Rows of (prob, prediction) keep pandas' per-column dtype inference the
    # same as building the frame from a list of row lists.
    test_df = pd.DataFrame(list(zip(prob, prediction)), columns = ["prob", "prediction"])

    results_csv_path = args['results_csv']
    if not results_csv_path:
        results_csv_path = os.path.splitext(args['test_filename'])[0] + '_results.csv'

    test_df.to_csv(results_csv_path, index=False)

    # Report some stats; measure elapsed time once so both lines agree.
    elapsed = time.time() - start_time
    num_images = len(test_labels)
    print("--- %s seconds to run inference on %d images ---" % (elapsed, num_images))
    if num_images:
        print("--- Average of %s seconds per image ---" % (elapsed / num_images))


if __name__ == '__main__':
    print_arguments(args = script_args)

    mode = script_args["mode"]
    # Reject unsupported modes before spending time building the model.
    if mode not in ("train", "test", "eval"):
        raise AssertionError("Given mode is not supported.")

    model = create_model(args = script_args)
    # summary() prints the table itself and returns None, so it is called
    # directly rather than printed (which would add a stray "None" line).
    model.summary()
    model_dir = os.path.join(os.getcwd(), script_args["save_dir"])

    if mode == "train":
        # Start each training run from an empty model directory so it does
        # not pick up checkpoints left behind by a previous run.
        if os.path.exists(model_dir):
            shutil.rmtree(model_dir)
        os.makedirs(model_dir)
        train(args = script_args, model = model, model_dir = model_dir)
    else:
        evaluate(args = script_args, model = model, model_dir = model_dir)
